# Copyright (c) Meta Platforms, Inc. and affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import dataclasses
import importlib.util
from typing import Union

# import jax
import numpy as np
import torch

# from jax import dlpack as jax_dlpack, numpy as jnp
from tensordict import make_tensordict, TensorDictBase
from torch.utils import dlpack as torch_dlpack
from torchrl.data.tensor_specs import Composite, TensorSpec, Unbounded
from torchrl.data.utils import numpy_to_torch_dtype_dict

# Whether the optional ``jax`` package is installed. It is only located here,
# not imported; the helpers in this module import jax lazily when called.
_has_jax = bool(importlib.util.find_spec("jax"))


def _tree_reshape(x, batch_size: torch.Size):
    """Expand the leading dimension of every leaf of the pytree ``x``.

    A leaf of shape ``(B, *rest)`` is reshaped to ``(*batch_size, *rest)``;
    this undoes :func:`_tree_flatten` when ``B == batch_size.numel()``.
    """
    import jax

    def _unflatten_leaf(leaf):
        # Keep everything after the first (flat) dimension untouched.
        return leaf.reshape(batch_size + leaf.shape[1:])

    return jax.tree_util.tree_map(_unflatten_leaf, x)


def _tree_flatten(x, batch_size: torch.Size):
    """Collapse the leading batch dimensions of every leaf of the pytree ``x``.

    The first ``len(batch_size)`` dimensions of each leaf (assumed to match
    ``batch_size``) are merged into a single dimension of size
    ``batch_size.numel()``; trailing dimensions are kept.
    """
    import jax

    n_batch_dims = len(batch_size)
    flat_prefix = (batch_size.numel(),)

    def _flatten_leaf(leaf):
        return leaf.reshape(flat_prefix + leaf.shape[n_batch_dims:])

    return jax.tree_util.tree_map(_flatten_leaf, x)


# Unsigned integer dtypes mapped to the signed integer type of the same width.
# ``_ndarray_to_tensor`` uses this to reinterpret (``.view``) such arrays
# bit-for-bit before handing them to torch — presumably because the torch side
# does not support these unsigned types; TODO confirm.
_dtype_conversion = {
    np.dtype(f"uint{bits}"): getattr(np, f"int{bits}") for bits in (16, 32, 64)
}


def _ndarray_to_tensor(
    value: Union["jnp.ndarray", np.ndarray]  # noqa: F821
) -> torch.Tensor:
    """Convert a numpy or JAX array to a torch tensor through DLPack.

    Unsigned 16/32/64-bit integer arrays are first reinterpreted bit-for-bit
    (``.view``) as the signed integer type of the same width, following
    ``_dtype_conversion``.

    Args:
        value: a ``np.ndarray`` or a JAX array. JAX is only imported when
            ``value`` is not a numpy array.

    Returns:
        A tensor cast to the torch dtype matching the (possibly reinterpreted)
        numpy dtype, looked up in ``numpy_to_torch_dtype_dict``.

    Raises:
        NotImplementedError: if ``value`` is neither a numpy nor a JAX array.
    """
    # Resolve the array kind *before* touching ``value.dtype`` so unsupported
    # inputs raise the intended NotImplementedError instead of AttributeError.
    if isinstance(value, np.ndarray):
        is_jax = False
    else:
        from jax import numpy as jnp

        if not isinstance(value, jnp.ndarray):
            raise NotImplementedError(f"unsupported data type {type(value)}")
        is_jax = True

    # JAX arrays generated by jax.vmap would have Numpy dtypes.
    if value.dtype in _dtype_conversion:
        value = value.view(_dtype_conversion[value.dtype])

    if not is_jax:
        dlpack_tensor = value.__dlpack__()
    elif hasattr(value, "__dlpack__"):
        # Modern JAX arrays implement the DLPack protocol and can be handed to
        # torch directly (which also lets torch handle stream sync);
        # ``jax.dlpack.to_dlpack`` is deprecated in recent JAX releases.
        dlpack_tensor = value
    else:
        # Fallback for older JAX versions without ``__dlpack__``.
        from jax import dlpack as jax_dlpack

        dlpack_tensor = jax_dlpack.to_dlpack(value)
    out = torch_dlpack.from_dlpack(dlpack_tensor)
    # dtype can be messed up by dlpack
    return out.to(numpy_to_torch_dtype_dict[value.dtype])


def _tensor_to_ndarray(value: torch.Tensor) -> "jnp.ndarray":  # noqa: F821
    """Convert a torch tensor to a JAX array through the DLPack protocol.

    The tensor is detached and made contiguous before export. The resulting
    JAX array may share memory with that (possibly copied) tensor.
    """
    from jax import dlpack as jax_dlpack

    # Hand the tensor itself (it implements ``__dlpack__``) to JAX rather than
    # a raw DLPack capsule: capsule input to ``jax.dlpack.from_dlpack`` is
    # deprecated in recent JAX releases. ``detach()`` is needed because
    # ``Tensor.__dlpack__`` refuses tensors that require grad.
    return jax_dlpack.from_dlpack(value.detach().contiguous())


def _get_object_fields(obj) -> dict:
    """Return the fields of ``obj`` as a ``{name: value}`` mapping.

    Supported inputs, checked in this order:

    * named tuples (any tuple exposing ``_fields``),
    * dataclasses,
    * dicts, which are returned as-is (not copied),
    * ``None``, which maps to an empty dict.

    Raises:
        NotImplementedError: for any other type.
    """
    if isinstance(obj, tuple) and hasattr(obj, "_fields"):
        return {name: item for name, item in zip(obj._fields, obj)}
    if dataclasses.is_dataclass(obj):
        return {f.name: getattr(obj, f.name) for f in dataclasses.fields(obj)}
    if isinstance(obj, dict):
        return obj
    if obj is None:
        return {}
    raise NotImplementedError(f"unsupported data type {type(obj)}")


def _object_to_tensordict(obj, device, batch_size) -> Union[TensorDictBase, None]:
    """Convert a namedtuple, dataclass or dict (recursively) to a TensorDict.

    Each field is handled as follows:

    * scalars (Python ``int``/``float``/``bool`` or numpy scalars) become
      1-element tensors;
    * numpy/JAX arrays are converted with ``_ndarray_to_tensor``;
    * anything else is treated as a nested object and converted recursively.

    Every tensor is moved to ``device``.

    Args:
        obj: the object to convert; ``None`` is treated as having no fields.
        device: target device of the tensors and of the TensorDict.
        batch_size: batch size of the resulting TensorDict (nested TensorDicts
            are built with the same batch size).

    Returns:
        The TensorDict, or ``None`` when no field produced an entry. Empty
        results are discarded, including nested ones.

    Raises:
        NotImplementedError: if a nested field has an unsupported type.
    """
    from jax import numpy as jnp

    t = {}
    for name, value in _get_object_fields(obj).items():
        # ``np.bool_`` is not an ``np.number``. Without it, numpy booleans would
        # be mistaken for nested objects and rejected, whereas Python ``bool``
        # (an ``int`` subclass) was already accepted here.
        if isinstance(value, (np.number, np.bool_, int, float)):
            t[name] = _ndarray_to_tensor(np.asarray([value])).to(device)
        elif isinstance(value, (jnp.ndarray, np.ndarray)):
            t[name] = _ndarray_to_tensor(value).to(device)
        else:
            nested = _object_to_tensordict(value, device, batch_size)
            if nested is not None:
                t[name] = nested
    if t:
        return make_tensordict(t, device=device, batch_size=batch_size)
    # discard empty tensordicts
    return None


def _tensordict_to_object(tensordict: TensorDictBase, object_example, batch_size=None):
    """Converts a TensorDict to a namedtuple or a dataclass.

    The result has the type and field names of ``object_example``. Each field
    is rebuilt from the entry of ``tensordict`` with the same name:

    * nested ``TensorDictBase`` entries are converted recursively;
    * missing entries become ``None``, or a recursively rebuilt dict when the
      example field is a dict;
    * tensors become JAX arrays carrying the example field's dtype (and, for
      non-0-d tensors, the shape ``(*batch_size, *example.shape)``).

    Args:
        tensordict: source of the values. The function also calls itself with
            a plain ``{}`` here, so only ``.get(name, None)`` is relied upon.
        object_example: a namedtuple, dataclass or dict (whatever
            ``_get_object_fields`` accepts). Its type and field names shape the
            output, and its leaf fields must expose ``dtype`` (and ``shape``
            for non-0-d tensors).
        batch_size: leading batch dimensions prepended to each example
            shape; ``None`` means no batch dimensions.

    Returns:
        ``type(object_example)(**fields)``.
    """
    from jax import dlpack as jax_dlpack, numpy as jnp

    if batch_size is None:
        batch_size = []
    t = {}
    _fields = _get_object_fields(object_example)
    for name, example in _fields.items():
        value = tensordict.get(name, None)
        if isinstance(value, TensorDictBase):
            # Nested object: recurse with the matching example field.
            t[name] = _tensordict_to_object(value, example, batch_size=batch_size)
        elif value is None:
            # Missing entry: a dict example is rebuilt from an empty mapping
            # (so its own leaves end up as None); other examples give None.
            if isinstance(example, dict):
                t[name] = _tensordict_to_object({}, example, batch_size=batch_size)
            else:
                t[name] = None
        else:
            # Leaf tensor -> JAX array with the example's dtype/shape.
            if value.dtype is torch.bool:
                # Exported as uint8 and reinterpreted as example.dtype below
                # (presumably to sidestep bool DLPack support — TODO confirm).
                value = value.to(torch.uint8)
            shape = value.shape
            # We need to flatten to fix https://github.com/pytorch/rl/issues/2184
            value = value.contiguous()
            value = value.detach()
            if value.ndim > 1:
                value = value.flatten().clone()
            else:
                # Need this because otherwise an exception is raised
                #  ValueError: INTERNAL: Address of buffer 1 must be a multiple of 10, but was 0x7efccec00824
                value = value.clone()
            value = jax_dlpack.from_dlpack(value)
            if shape.numel() == 1 and not value.shape:
                # 0-d result: prepend singleton axes until the original torch
                # shape is matched. NOTE(review): value is only 0-d here when
                # the source tensor was 0-d (shape == ()), so this loop appears
                # never to iterate — confirm before relying on it.
                while value.shape != shape:
                    value = jnp.expand_dims(value, 0)
                if value.dtype != example.dtype:
                    t[name] = value.view(example.dtype)
                else:
                    t[name] = value
            else:
                # Undo the flatten above, reinterpret the bits as the example
                # dtype, then reshape to (*batch_size, *example.shape).
                value = jnp.reshape(value, tuple(shape))
                t[name] = value.view(example.dtype).reshape(
                    (*batch_size, *example.shape)
                )
    return type(object_example)(**t)


def _extract_spec(data: Union[torch.Tensor, TensorDictBase], key=None) -> TensorSpec:
    """Build an unbounded spec mirroring the structure of ``data``.

    Args:
        data: a tensor, or a TensorDict that is converted recursively into a
            ``Composite`` with the same keys.
        key: the name under which ``data`` is stored in its parent. Tensors
            stored under ``"reward"`` or ``"done"`` get an extra trailing
            dimension of size 1.

    Returns:
        For a tensor, an ``Unbounded`` spec with its shape (adjusted as
        described for ``key``), dtype and device. For a TensorDict, a
        ``Composite`` of such specs.

    Raises:
        TypeError: if ``data`` is neither a tensor nor a TensorDict.
    """
    if isinstance(data, torch.Tensor):
        shape = data.shape
        if key in ("reward", "done"):
            shape = (*shape, 1)
        # Float and non-float tensors alike get an ``Unbounded`` spec.
        return Unbounded(shape=shape, dtype=data.dtype, device=data.device)
    if isinstance(data, TensorDictBase):
        return Composite(
            {
                subkey: _extract_spec(value, key=subkey)
                for subkey, value in data.items()
            }
        )
    raise TypeError(f"Unsupported data type {type(data)}")
